package org.zjvis.datascience.spark.algorithm;

import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.fpm.AssociationRules;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.zjvis.datascience.spark.util.OptionHelper;
import org.zjvis.datascience.spark.util.OutputResult;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.Tuple3;
import scala.collection.JavaConversions;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @description Spark FP-Growth operator: mines frequent itemsets and association rules.
 * @date 2021-12-23
 */
public class FPGrowthAlgorithm extends BaseAlgorithm {

    /** Columns read from the source table; each row's distinct non-empty values form one transaction. */
    private String[] featureCols;

    /** Input table (Hive or Greenplum depending on {@code isHive} — set by the base class). */
    private String sourceTable;

    /** Output table for the association rules (rule string + confidence). */
    private String targetTable;

    /** Output table for the frequent patterns (id, pattern, frequency, support). */
    private String targetTable2;

    /** Minimum support handed to {@link FPGrowth#setMinSupport(double)}. */
    private double minSupport;

    /** Minimum confidence for {@link FPGrowthModel#generateAssociationRules(double)}. */
    private double minConfidence;

    /** Minimum absolute frequency a pattern must reach to be written to the output. */
    private int minFreq;


    public FPGrowthAlgorithm(SparkSession sparkSession) {
        super(sparkSession);
    }


    /**
     * Parses command-line options for this operator on top of the base-class options.
     *
     * @param args raw command-line arguments
     * @return {@code true} when the command line parsed and all required options are present;
     *         {@code false} otherwise (an error is logged)
     */
    public boolean parseParams(String[] args) {
        super.parseParams(args);

        options.addOption(OptionHelper.getSpecOption("source"));
        options.addOption(OptionHelper.getSpecOption("target"));
        options.addOption(OptionHelper.getSpecOption("featureCols"));
        options.addOption(OptionHelper.getSpecOption("mins", "minSupport", "min support", true));
        options.addOption(OptionHelper.getSpecOption("minc", "minConf", "min confidence", true));
        options.addOption(OptionHelper.getSpecOption("minf", "minFreq", "min frequency", true));
        options.addOption(OptionHelper.getSpecOption("t2", "target2", "target table 2", true));

        if (!this.initCommandLine(args)) {
            logger.error("initCommandLine fail!!!");
            return false;
        }
        // commons-cli resolves short and long names to the same Option, so
        // getOptionValue(longName) works whichever form was on the command line.
        if (cmd.hasOption("f") || cmd.hasOption("featureCols")) {
            featureCols = cmd.getOptionValue("featureCols").split(",");
        }
        if (cmd.hasOption("s") || cmd.hasOption("source")) {
            sourceTable = cmd.getOptionValue("source");
        }
        if (cmd.hasOption("t") || cmd.hasOption("target")) {
            targetTable = cmd.getOptionValue("target");
        }
        if (cmd.hasOption("mins") || cmd.hasOption("minSupport")) {
            minSupport = Double.parseDouble(cmd.getOptionValue("minSupport"));
        }
        if (cmd.hasOption("minf") || cmd.hasOption("minFreq")) {
            minFreq = Integer.parseInt(cmd.getOptionValue("minFreq"));
        }
        if (cmd.hasOption("minc") || cmd.hasOption("minConf")) {
            minConfidence = Double.parseDouble(cmd.getOptionValue("minConf"));
        }
        if (cmd.hasOption("t2") || cmd.hasOption("target2")) {
            targetTable2 = cmd.getOptionValue("target2");
        }
        // Fail fast with a clear message instead of NPE-ing deep inside beginAlgorithm().
        if (featureCols == null || StringUtils.isEmpty(sourceTable)
                || StringUtils.isEmpty(targetTable) || StringUtils.isEmpty(targetTable2)) {
            logger.error("missing required option(s): featureCols/source/target/target2");
            return false;
        }
        return true;
    }

    /**
     * Runs FP-Growth over the selected feature columns, writes the frequent
     * patterns to {@code targetTable2} and the association rules to
     * {@code targetTable}, then records the run's output metadata.
     *
     * @return {@code true} on success, {@code false} when caching the sampled
     *         source table fails
     */
    public boolean beginAlgorithm() {
        Dataset<Row> dataset;
        if (isHive) {
            String sql = UtilTool.buildSelectSql(featureCols, sourceTable, "");
            dataset = sparkSession.sql(sql);
        } else {
            if (isSample) {
                String cacheTable = cacheUtil.modifyCacheTableName(sourceTable);
                if (cacheUtil.isCacheTableExists(cacheTable)) {
                    dataset = sparkSession.table(cacheTable)
                            .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
                } else {
                    Dataset<Row> cacheRdd = UtilTool.readFromGreenPlum(sparkSession, sourceTable, "", sampleNumber);
                    dataset = cacheRdd.select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
                    if (!cacheUtil.cacheTableForDataset(cacheRdd, cacheTable)) {
                        logger.error("table = {} cache fail!!!", cacheTable);
                        return false;
                    }
                    logger.info("table = {} cache success", cacheTable);
                }
            } else {
                dataset = UtilTool.readFromGreenPlum(sparkSession, sourceTable, "")
                        .select(JavaConversions.asScalaBuffer(UtilTool.selectColumns(featureCols, "")));
            }
        }

        // Copy instance fields into effectively-final locals so the Spark
        // closures below capture only these values. A lambda reading an
        // instance field captures `this`, forcing serialization of the whole
        // algorithm object (a Task-not-serializable hazard).
        final int length = featureCols.length;
        final int minFreqLocal = minFreq;

        // One transaction per row: distinct, non-empty string values of the feature columns.
        JavaRDD<List<String>> transactions = dataset.toJavaRDD().map(row -> {
            List<String> items = new ArrayList<>();
            for (int i = 0; i < length; ++i) {
                String s = row.get(i) == null ? "" : row.get(i).toString();
                if (StringUtils.isNotEmpty(s) && !items.contains(s)) {
                    items.add(s);
                }
            }
            return items;
        }).persist(StorageLevel.MEMORY_AND_DISK()); // reused by run() and count()

        FPGrowth fpg = new FPGrowth()
                .setMinSupport(minSupport);

        // Typed model instead of the original raw type.
        FPGrowthModel<String> model = fpg.run(transactions);
        //TODO
        // 1. save model?
        // 2. save pattern and frequency table?
        long totalNum = transactions.count();
        JavaRDD<FPGrowth.FreqItemset<String>> patFreq = model.freqItemsets().toJavaRDD();
        JavaRDD<Row> patternRdd = patFreq.mapPartitions(iter -> {
            List<Tuple3<String, Long, Double>> result = new ArrayList<>();
            while (iter.hasNext()) {
                FPGrowth.FreqItemset<String> freqItemset = iter.next();
                if (freqItemset.freq() >= minFreqLocal) {
                    result.add(new Tuple3<>(freqItemset.javaItems().toString(), freqItemset.freq(),
                            freqItemset.freq() * 1.0 / totalNum));
                }
            }
            return result.iterator();
        }).zipWithUniqueId()
                // Math.toIntExact fails fast on overflow, replacing the
                // original Long -> String -> Integer.parseInt round trip.
                .map(row -> RowFactory.create(Math.toIntExact(row._2), row._1._1(), row._1._2(), row._1._3()));

        StructType structType1 = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("_record_id_", DataTypes.IntegerType, true),
                DataTypes.createStructField("pattern", DataTypes.StringType, true),
                DataTypes.createStructField("frequency", DataTypes.LongType, true),
                DataTypes.createStructField("support", DataTypes.DoubleType, true)
        });
        Dataset<Row> df2 = sparkSession.createDataFrame(patternRdd, structType1);

        if (isHive) {
            this.saveHiveTable(df2, targetTable2);
        } else {
            UtilTool.saveGreenplumTable(df2, targetTable2);
        }

        // minConfidence is evaluated on the driver here, so no closure-capture issue.
        JavaRDD<AssociationRules.Rule<String>> rules = model.generateAssociationRules(minConfidence).toJavaRDD();

        JavaRDD<Row> ruleDF = rules.map(rule ->
                RowFactory.create(String.format("%s=>%s", rule.javaAntecedent().toString(), rule.javaConsequent().toString()), rule.confidence()));
        StructType structType = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("association_rule", DataTypes.StringType, true),
                DataTypes.createStructField("confidence", DataTypes.DoubleType, true)
        });

        Dataset<Row> df = sparkSession.createDataFrame(ruleDF, structType);

        if (isHive) {
            this.saveHiveTable(df, targetTable);
        } else {
            UtilTool.saveGreenplumTable(df, targetTable);
        }

        // Record the two output tables (patterns first, rules second) in the run metadata.
        Map<String, Object> inputKV = new HashMap<>();
        List<String> outputTables = new ArrayList<>();
        outputTables.add(targetTable2);
        outputTables.add(targetTable);

        String metaPath = String.format(UtilTool.RUN_META_PATH, algorithmName, uniqueKey);

        OutputResult outputResult = this.buildOutputResult(outputTables, 0, "", inputKV);

        if (isHive) {
            this.saveOutputResultForHive(sparkSession, outputResult, metaPath);
        } else {
            this.saveOutputResultForMysql(sparkSession, outputResult);
        }
        retResult = outputResult;
        transactions.unpersist(false); // non-blocking: don't stall the driver on cleanup
        return true;
    }
}